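# Handles model training: fits the model (with early-stopping, learning-rate-reduction and
# CSV-logging callbacks), saves the trained model, and writes a new file containing the
# "model summary" together with the imported global configuration variables. Finally it
# returns the training "history". Note: no checkpoint callback is currently registered.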
def model_train(model, ds_train, ds_test, model_type_index, training_result_dir):
    """Train ``model``, save it, and write a plain-text report about the run.

    The model is fitted with three callbacks:
    - early stopping, monitoring ``val_loss``;
    - learning-rate reduction on plateau, also monitoring ``val_loss``;
    - a CSV logger writing ``<training_result_dir>/CSVLogger.csv``.

    Afterwards the model is saved to ``<training_result_dir>/model``. A report
    is then written to ``<training_result_dir>/model_summary.txt``. It contains
    the Keras model summary, the wall-clock training time, and the global
    settings imported from ``Global_variable_setting``.

    Args:
        model: the (compiled) Keras model to train.
        ds_train: dataset used for training.
        ds_test: dataset passed as ``validation_data``; its loss drives early
            stopping and learning-rate reduction.
        model_type_index: 0 for a Vision Transformer; any other value is
            treated as a CNN. Only affects which settings are reported.
        training_result_dir: directory for all training artifacts. It (and
            its ``training`` and ``model`` sub-directories) is created if
            missing.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    import contextlib
    import os
    import time

    import tensorflow as tf

    from Global_variable_setting import training_dataset_length, test_dataset_length, size_of_batch
    from Global_variable_setting import optimizer_initial_lr, weight_decay, loss_values, metrics_values
    from Global_variable_setting import image_size, patch_size, num_patches
    from Global_variable_setting import projection_dim, num_heads, transformer_hidden_units
    from Global_variable_setting import transformer_layers, mlp_head_units
    from Global_variable_setting import dropout, emb_dropout

    # makedirs(exist_ok=True) creates missing parents and tolerates an already
    # existing training_result_dir (a bare os.mkdir would raise FileExistsError
    # there). The 'training' sub-directory is reserved for checkpoint files;
    # no checkpoint callback is currently registered.
    os.makedirs(os.path.join(training_result_dir, 'training'), exist_ok=True)

    csv_callback = tf.keras.callbacks.CSVLogger(
        filename=os.path.join(training_result_dir, 'CSVLogger.csv'),
        append=False,  # overwrite a log left over from a previous run
    )
    es_callback = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        min_delta=1e-6,
        patience=10,
        verbose=1,
        restore_best_weights=True,  # keep the best epoch, not the last one
    )
    rlr_callback = tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.1,
        # Shorter than the early-stopping patience (10), so the learning rate
        # is lowered before EarlyStopping gives up.
        patience=7,
        verbose=1,
        mode='auto',
        cooldown=0,
        min_lr=1e-7,
    )

    # perf_counter is monotonic, so the measured duration is not affected by
    # system clock adjustments during a long training run.
    start_time = time.perf_counter()
    history = model.fit(
        x=ds_train,
        validation_data=ds_test,
        # Effectively unbounded; in practice EarlyStopping ends training.
        epochs=100000,
        callbacks=[es_callback, rlr_callback, csv_callback],
    )
    dur_time = time.perf_counter() - start_time

    model_dir = os.path.join(training_result_dir, 'model')
    os.makedirs(model_dir, exist_ok=True)
    model.save(model_dir)

    summary_path = os.path.join(training_result_dir, 'model_summary.txt')
    # The with-statements guarantee that the file is closed and sys.stdout is
    # restored even if model.summary() or one of the prints raises.
    with open(summary_path, 'w', encoding='utf-8') as summary_file:
        with contextlib.redirect_stdout(summary_file):
            # model.summary() is relied on to print through stdout, so its
            # output lands in the summary file together with the lines below.
            model.summary()
            print()
            print()
            print(f'The training time of this model is {dur_time} s')
            print()
            print()
            print('(1) The parameters about dataset:')
            print('    Training dataset length:', training_dataset_length)
            print('    Test dataset length:', test_dataset_length)
            print('    Size of batch:', size_of_batch)
            print()
            print('(2) The configuration of the model compile parameters:')
            print('    The optimizer of model is:', str(model.optimizer))
            print('    The initial learning rate of model optimizer is:', optimizer_initial_lr)
            print(f'    The weight decay setting of optimizer is:{weight_decay} (Tack care of optimizer type)')
            print('    The loss value of model compiling process is:', loss_values)
            print('    The metrics value of model compiling process is:', metrics_values)

            if model_type_index == 0:
                print()
                print('(3) The configuration parameters about building process of Vit model:')
                print(f'    input shape:({image_size},{image_size},1)')
                print(f'    patch size:{patch_size}')
                print(f'    num_patches:{num_patches}')
                print(f'    projection_dim:{projection_dim}')
                print(f'    num_heads:{num_heads}')
                print(f'    transformer_hidden_units:{transformer_hidden_units}')
                print(f'    transformer layers:{transformer_layers}')
                print(f'    mlp_head_units:{mlp_head_units}')
                print(f'    dropout:{dropout}')
                print(f'    emb_dropout:{emb_dropout}')
            else:
                print()
                print('(3) The configuration parameters about building process of CNN model:')
                print(f'    dropout:{dropout}')

    return history


# Evaluates the model's performance and returns the result ("history") as a dict.
def model_evaluate(model, dataset):
    """Evaluate ``model`` on ``dataset`` and return the metrics as a dict.

    Args:
        model: the (compiled) Keras model to evaluate.
        dataset: data passed to ``model.evaluate`` as ``x``.

    Returns:
        Whatever ``model.evaluate(..., return_dict=True)`` returns, i.e. the
        evaluation results keyed by metric name.
    """
    evaluate_kwargs = {'x': dataset, 'verbose': 2, 'return_dict': True}
    return model.evaluate(**evaluate_kwargs)
